# Copyright (c) Microsoft Corporation

import torch.nn as nn
from . import networks


class Mapping_Model_with_mask(nn.Module):
    """Feature-mapping network with one masked non-local block in the middle.

    Layout, in execution order:

    * ``before_NL``: four stride-1 3x3 conv + norm + ReLU stages. The channel
      width starts at 64 and doubles at every stage, capped at ``mc``.
    * ``NL``: a ``networks.NonLocalBlock2D_with_mask_Res`` that receives the
      features together with the mask. It is only created when
      ``opt.NL_res`` is truthy.
    * ``after_NL``: ``n_blocks`` ``networks.ResnetBlock`` modules, three conv
      stages mirroring ``before_NL``, a final 3x3 conv down to 64 channels
      and, when ``0 < opt.feat_dim < 64``, a norm + ReLU + 1x1 conv projecting
      to ``opt.feat_dim`` channels.

    Args:
        nc: Accepted for interface compatibility; not used by this class.
        mc: Upper bound on the channel width of every stage.
        n_blocks: Number of ResnetBlocks placed after the non-local block.
        norm: Normalisation name passed to ``networks.get_norm_layer``.
        padding_type: Padding mode passed to every ``ResnetBlock``.
        opt: Options object. Despite the ``None`` default it must be supplied:
            the constructor reads ``NL_res``, ``mapping_net_dilation`` and
            ``feat_dim`` unconditionally, and, when ``NL_res`` is truthy,
            ``NL_fusion_method``, ``correlation_renormalize``,
            ``softmax_temperature``, ``use_self`` and ``cosin_similarity``.
    """

    def __init__(
        self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None
    ):
        super().__init__()

        norm_layer = networks.get_norm_layer(norm_type=norm)
        # A single in-place ReLU instance is reused at every position.
        activation = nn.ReLU(True)

        tmp_nc = 64  # base channel width
        n_up = 4  # number of widening stages

        def conv_stage(in_ch, out_ch):
            # 3x3 conv, stride 1, padding 1, followed by norm and activation.
            return [nn.Conv2d(in_ch, out_ch, 3, 1, 1), norm_layer(out_ch), activation]

        widening = []
        for i in range(n_up):
            ic = min(tmp_nc * 2**i, mc)
            oc = min(tmp_nc * 2 ** (i + 1), mc)
            widening += conv_stage(ic, oc)
        self.before_NL = nn.Sequential(*widening)

        # NOTE(review): when opt.NL_res is falsy no `NL` attribute is created,
        # yet forward() always uses it -- confirm NL_res is always enabled.
        if opt.NL_res:
            self.NL = networks.NonLocalBlock2D_with_mask_Res(
                mc,
                mc,
                opt.NL_fusion_method,
                opt.correlation_renormalize,
                opt.softmax_temperature,
                opt.use_self,
                opt.cosin_similarity,
            )

            print("using NL + Res...")

        narrowing = [
            networks.ResnetBlock(
                mc,
                padding_type=padding_type,
                activation=activation,
                norm_layer=norm_layer,
                opt=opt,
                dilation=opt.mapping_net_dilation,
            )
            for _ in range(n_blocks)
        ]

        # Mirror of the widening stages: the width halves each step (still
        # capped at mc) until it reaches min(tmp_nc * 2, mc).
        for i in range(n_up - 1):
            ic = min(tmp_nc * 2 ** (n_up - i), mc)
            oc = min(tmp_nc * 2 ** (n_up - 1 - i), mc)
            narrowing += conv_stage(ic, oc)

        # The input width must equal what the previous stage emits. Using a
        # fixed tmp_nc * 2 here mismatches the incoming features whenever
        # mc < 2 * tmp_nc (including the default mc=64).
        narrowing.append(nn.Conv2d(min(tmp_nc * 2, mc), tmp_nc, 3, 1, 1))

        if 0 < opt.feat_dim < 64:
            narrowing += [
                norm_layer(tmp_nc),
                activation,
                nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1),
            ]
        self.after_NL = nn.Sequential(*narrowing)

    def forward(self, input, mask):
        """Run ``before_NL``, then ``NL`` with ``mask``, then ``after_NL``.

        Local references to tensors that are no longer needed are dropped as
        soon as possible so this method does not keep them alive itself.

        Raises:
            AttributeError: If the model was built with a falsy
                ``opt.NL_res`` (no ``NL`` module exists in that case).
        """
        features = self.before_NL(input)
        del input
        # Rebinding `features` releases this method's reference to the
        # pre-attention tensor.
        features = self.NL(features, mask)
        del mask
        return self.after_NL(features)


class Mapping_Model_with_mask_2(nn.Module):
    """Feature-mapping network using multi-scale patch attention.

    Layout, in execution order:

    * ``before_NL``: four stride-1 3x3 conv + norm + ReLU stages (width starts
      at 64, doubles per stage, capped at ``mc``) followed by two
      ``networks.ResnetBlock`` modules.
    * ``NL_scale_1`` -> ``res_block_1`` -> ``NL_scale_2`` -> ``res_block_2``
      -> ``NL_scale_3``: three ``networks.Patch_Attention_4`` modules built
      with third argument 8, 4 and 2 (presumably the patch scale -- confirm
      against ``networks.Patch_Attention_4``), separated by pairs of
      ResnetBlocks. The attention modules are only created when
      ``opt.mapping_exp == 1``.
    * ``after_NL``: two ResnetBlocks, three conv stages mirroring the widening
      stages, a final 3x3 conv down to 64 channels and, when
      ``0 < opt.feat_dim < 64``, a norm + ReLU + 1x1 conv projecting to
      ``opt.feat_dim`` channels.

    Args:
        nc: Accepted for interface compatibility; not used by this class.
        mc: Upper bound on the channel width of every stage.
        n_blocks: Accepted for interface compatibility; not used by this
            class (every ResnetBlock group has a fixed size of two).
        norm: Normalisation name passed to ``networks.get_norm_layer``.
        padding_type: Padding mode passed to every ``ResnetBlock``.
        opt: Options object. Despite the ``None`` default it must be supplied:
            the constructor reads ``mapping_net_dilation``, ``mapping_exp``
            and ``feat_dim`` unconditionally.
    """

    def __init__(
        self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None
    ):
        super().__init__()

        norm_layer = networks.get_norm_layer(norm_type=norm)
        # A single in-place ReLU instance is reused at every position.
        activation = nn.ReLU(True)

        tmp_nc = 64  # base channel width
        n_up = 4  # number of widening stages

        def conv_stage(in_ch, out_ch):
            # 3x3 conv, stride 1, padding 1, followed by norm and activation.
            return [nn.Conv2d(in_ch, out_ch, 3, 1, 1), norm_layer(out_ch), activation]

        def res_blocks(count):
            # `count` freshly constructed ResnetBlocks sharing one configuration.
            return [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
                for _ in range(count)
            ]

        layers = []
        for i in range(n_up):
            ic = min(tmp_nc * 2**i, mc)
            oc = min(tmp_nc * 2 ** (i + 1), mc)
            layers += conv_stage(ic, oc)
        layers += res_blocks(2)

        print("using multi-scale patch attention, conv combine + mask input...")

        self.before_NL = nn.Sequential(*layers)

        # NOTE(review): for any other mapping_exp value the NL_scale_*
        # attributes are never created, yet both forward paths always use
        # them -- confirm mapping_exp == 1 is the only supported setting.
        use_patch_attention = opt.mapping_exp == 1

        # Attribute assignment order is kept as before_NL, NL_scale_1,
        # res_block_1, NL_scale_2, res_block_2, NL_scale_3, after_NL so that
        # sub-module registration order is unchanged.
        if use_patch_attention:
            self.NL_scale_1 = networks.Patch_Attention_4(mc, mc, 8)

        self.res_block_1 = nn.Sequential(*res_blocks(2))

        if use_patch_attention:
            self.NL_scale_2 = networks.Patch_Attention_4(mc, mc, 4)

        self.res_block_2 = nn.Sequential(*res_blocks(2))

        if use_patch_attention:
            self.NL_scale_3 = networks.Patch_Attention_4(mc, mc, 2)

        layers = res_blocks(2)

        # Mirror of the widening stages: the width halves each step (still
        # capped at mc) until it reaches min(tmp_nc * 2, mc).
        for i in range(n_up - 1):
            ic = min(tmp_nc * 2 ** (n_up - i), mc)
            oc = min(tmp_nc * 2 ** (n_up - 1 - i), mc)
            layers += conv_stage(ic, oc)

        # The input width must equal what the previous stage emits. Using a
        # fixed tmp_nc * 2 here mismatches the incoming features whenever
        # mc < 2 * tmp_nc (including the default mc=64).
        layers.append(nn.Conv2d(min(tmp_nc * 2, mc), tmp_nc, 3, 1, 1))

        if 0 < opt.feat_dim < 64:
            layers += [
                norm_layer(tmp_nc),
                activation,
                nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1),
            ]
        self.after_NL = nn.Sequential(*layers)

    def forward(self, input, mask):
        """Apply the full pipeline, calling each attention module directly.

        The same ``mask`` is passed to all three attention modules.

        Raises:
            AttributeError: If the model was built with
                ``opt.mapping_exp != 1`` (no attention modules exist then).
        """
        features = self.before_NL(input)
        features = self.NL_scale_1(features, mask)
        features = self.res_block_1(features)
        features = self.NL_scale_2(features, mask)
        features = self.res_block_2(features)
        features = self.NL_scale_3(features, mask)
        return self.after_NL(features)

    def inference_forward(self, input, mask):
        """Same pipeline as ``forward`` via each attention's ``inference_forward``.

        Local references to intermediates are dropped as soon as the next
        stage has consumed them, so this method does not keep them alive.

        Raises:
            AttributeError: If the model was built with
                ``opt.mapping_exp != 1`` (no attention modules exist then).
        """
        features = self.before_NL(input)
        del input
        # Each rebinding below releases this method's reference to the
        # previous stage's output.
        features = self.NL_scale_1.inference_forward(features, mask)
        features = self.res_block_1(features)
        features = self.NL_scale_2.inference_forward(features, mask)
        features = self.res_block_2(features)
        features = self.NL_scale_3.inference_forward(features, mask)
        return self.after_NL(features)
